{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 线性回归\n",
    "PyTorch 深度学习的基本技能已经掌握，从本篇开始进入深度学习模型，任何模型都具备以下基本要素，弄懂一个模型非常简单\n",
    "\n",
    "+ 数据：数据就是输入和已知的输出【监督学习可能没有已知输出】\n",
    "+ 模型：模型就是数学中的一个函数镞，可变化的就是参数，输入到输出的一个映射，一般输出都是一个具体的值，可以是概率和数值\n",
    "+ 参数：一个函数必然包含固定个数或者很多参数\n",
    "+ 损失函数：对任意一个模型，我们总能找到一个办法衡量它的好坏，这个办法就是通过损失函数\n",
    "+ 优化方法：对一个固定的损失函数，我们通过优化算法，就能求出模型的参数\n",
    "\n",
    "\n",
    "上面每一个概念在 PyTorch 中都有一个对应方式    \n",
    "以后我们统一 y 表示输出，x1，x2，..... 表示单个输入变量，X 表示输入向量，一个样本\n",
    "\n",
    "## 线性回归模型\n",
    "线性回归模型怎么对应上面的几个基本构成呢\n",
    "+ 数据：根据问题来的，比如预测某个地区的房价，必然会收集很多房子的特征：面积、楼龄、周边商超数量、卧室数量等，还会收集它现在的价格，这个就是数据\n",
    "+ 模型：线性回归就是规定函数形式就是：p = b + a1 \\* x1 + a2 \\* x2 + an * xm, b 是 bias\n",
    "+ 参数：b a1 a2 am 就是参数\n",
    "+ 损失函数：z = sum(sum((y-p) * (y-p))/m)/n，其中 m 是参数个数，n 是样本数量\n",
    "+ 优化方法：梯度下降等，得益于 PyTorch 自动求梯度，我们不再需要显示求出损失函数的导数\n",
    "\n",
    "当模型和损失函数形式较为简单时，上面的误差最小化问题的解可以直接用公式表达出来。这类解叫作解析解（analytical solution）。本节使用的线性回归和平方误差刚好属于这个范畴。然而，大多数深度学习模型并没有解析解，只能通过优化算法有限次迭代模型参数来尽可能降低损失函数的值。这类解叫作数值解（numerical solution）。\n",
    "\n",
    "## 优化方法\n",
    "在求数值解的优化算法中，小批量随机梯度下降（mini-batch stochastic gradient descent）在深度学习中被广泛使用。它的算法很简单：先选取一组模型参数的初始值，如随机选取；接下来对参数进行多次迭代，使每次迭代都可能降低损失函数的值。在每次迭代中，先随机均匀采样一个由固定数目训练数据样本所组成的小批量（mini-batch）[Math Processing Error]B，然后求小批量中数据样本的平均损失有关模型参数的导数（梯度），最后用此结果与预先设定的一个正数的乘积作为模型参数在本次迭代的减小量。\n",
    "\n",
    "下面我们来用 PyTorch 来构建整个线性回归模型。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 数据准备\n",
    "下面简单模拟一些测试数据\n",
    "+ 假设总共两个变量，变量对应的线性权重为：[2, -3.4]\n",
    "+ 假设bias = 4.2\n",
    "+ 假设误差服从均值为 0，方差为 0.1 的正态分布\n",
    "+ 假设样本个数为：1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "X.shape= torch.Size([1000, 2])\ntrue_weights.shape torch.Size([2, 1])\nerror.shape torch.Size([1000])\nY.shape torch.Size([1000])\ntensor([3.8446, 4.0907, 4.1172, 3.3592, 5.2081, 2.7656, 5.2604, 2.8241, 2.4092,\n        2.1524, 5.2988, 1.7538, 4.8105, 1.9225, 3.7813, 4.3766, 4.1147, 3.0951,\n        4.5631, 3.5313, 4.3096, 2.9987, 2.5910, 1.9592, 3.1709, 3.9261, 4.9781,\n        4.0535, 2.6952, 6.1048, 3.6954, 3.6510, 3.0196, 5.7960, 4.4282, 2.5839,\n        3.0691, 4.9949, 2.3908, 3.5015, 3.6174, 5.2803, 3.9719, 5.3797, 4.2355,\n        4.4568, 2.3278, 4.5102, 2.8263, 3.0402, 3.8340, 4.2123, 2.1433, 5.4706,\n        3.0473, 3.7451, 3.7015, 4.2973, 4.2556, 2.6841, 2.4219, 2.6268, 3.1156,\n        3.2309, 4.6750, 3.0694, 4.9669, 2.5136, 1.6487, 3.6050, 2.9262, 4.2725,\n        4.7486, 5.4740, 3.4866, 3.1768, 2.2584, 2.8020, 1.7936, 5.4785, 4.8686,\n        3.1204, 2.2605, 3.3992, 3.0475, 3.2269, 3.1219, 3.0383, 5.0227, 3.4884,\n        4.2764, 2.6713, 3.5684, 3.4632, 2.2490, 1.4212, 4.8428, 4.5983, 5.7281,\n        1.1053, 1.3899, 2.6281, 4.4704, 3.0867, 3.2902, 1.9691, 3.3374, 2.5070,\n        3.0917, 4.1047, 3.8225, 4.3696, 4.3819, 1.4250, 1.7278, 4.0682, 2.1029,\n        4.6581, 2.6579, 1.0842, 3.2255, 4.1708, 2.9084, 5.0104, 5.3007, 4.1932,\n        2.7696, 5.2778, 2.4161, 4.5281, 3.6338, 4.6920, 3.0136, 5.4588, 3.1818,\n        1.7138, 2.8399, 5.5421, 4.9831, 4.4050, 4.0265, 3.9643, 3.3140, 2.4577,\n        2.8055, 2.7722, 3.5349, 4.7602, 4.1784, 2.6367, 5.2617, 4.1096, 2.4122,\n        1.2558, 3.7779, 3.6719, 4.0400, 1.3096, 6.0613, 3.2657, 4.0496, 4.9244,\n        2.8002, 3.4794, 2.2826, 5.4166, 2.7551, 2.8156, 2.3018, 3.3694, 2.7911,\n        2.5651, 5.9932, 5.0861, 5.2977, 3.8613, 4.3364, 4.4534, 4.5296, 4.8354,\n        4.1069, 4.7309, 4.2550, 2.5269, 2.3429, 2.0414, 2.4935, 1.1287, 3.0971,\n        1.2307, 5.5333, 3.4166, 2.8524, 2.3041, 4.9462, 1.4070, 4.0568, 2.3239,\n        1.7112, 2.0723, 4.2905, 3.7974, 3.8739, 3.7477, 4.3236, 1.5957, 
1.1051,\n        4.3411, 1.8666, 3.0720, 4.8836, 4.5631, 4.6138, 3.6825, 3.7021, 1.4954,\n        2.7708, 2.7047, 3.4470, 2.5475, 1.4418, 4.9161, 1.7931, 1.1804, 3.9426,\n        2.8327, 5.0788, 2.6199, 1.7021, 5.3544, 2.1901, 5.9561, 3.0831, 4.1852,\n        4.3918, 4.1454, 5.7075, 1.8844, 3.8927, 2.5753, 3.6926, 2.9570, 5.3129,\n        2.2550, 1.3898, 1.2307, 2.7502, 4.2561, 3.7416, 2.3882, 4.4464, 5.1021,\n        3.6374, 5.6678, 1.1898, 4.4163, 4.0352, 3.0646, 2.3684, 6.1629, 2.5762,\n        3.0567, 1.5983, 5.0657, 2.9574, 5.4633, 3.2944, 1.5581, 1.8404, 2.4713,\n        3.4160, 4.6793, 4.5125, 4.1209, 4.3137, 3.4236, 3.8988, 2.7881, 3.1954,\n        2.1906, 1.6602, 2.2132, 2.8204, 3.5093, 1.5236, 3.9393, 5.1923, 2.3317,\n        3.1636, 4.6599, 3.6859, 1.6121, 4.0461, 2.7652, 3.3690, 3.9002, 4.5103,\n        4.4520, 1.9180, 4.2729, 2.5159, 2.8345, 3.4961, 3.4104, 3.0264, 3.9921,\n        2.6182, 5.3045, 3.3655, 3.3694, 2.6716, 1.3454, 3.8676, 3.0175, 5.1206,\n        3.0445, 4.3336, 2.7594, 4.1684, 3.4109, 4.1006, 3.4700, 3.3206, 2.5965,\n        2.8328, 4.0561, 2.2106, 4.3804, 2.3170, 2.8446, 4.0964, 6.0516, 4.8875,\n        2.1912, 3.7423, 2.9094, 5.8844, 4.6662, 4.3280, 4.2454, 4.6310, 4.0153,\n        3.9877, 1.4614, 3.6621, 1.8053, 5.0628, 3.7873, 3.5244, 3.8451, 2.3934,\n        3.8147, 4.1964, 5.2099, 3.3625, 2.9344, 3.6468, 5.4720, 3.7515, 2.6382,\n        3.9963, 4.1561, 1.9310, 3.9945, 2.2824, 5.1472, 3.9641, 5.4321, 1.4979,\n        3.5721, 2.5182, 5.5424, 2.1057, 5.6653, 4.0678, 5.7093, 4.5317, 2.0919,\n        4.6296, 4.3642, 2.6222, 3.2959, 1.4941, 3.5131, 3.8708, 4.5390, 2.9980,\n        4.4479, 3.3798, 3.7509, 4.8939, 2.0588, 3.2340, 5.0307, 4.2699, 3.0631,\n        1.3911, 4.0911, 4.1031, 2.8581, 4.0844, 1.0670, 2.6033, 2.2940, 1.6506,\n        3.5137, 3.8849, 2.8953, 3.1624, 2.8559, 2.8540, 4.0983, 3.9319, 2.3805,\n        4.6426, 1.8018, 2.8803, 3.6204, 1.9513, 3.3372, 3.5653, 3.7454, 3.9115,\n        2.6576, 2.6302, 3.5995, 3.4637, 
1.9407, 4.0688, 2.3407, 4.7637, 1.6325,\n        2.0965, 3.1685, 3.3711, 6.4365, 4.7217, 4.5149, 4.3602, 1.4564, 3.4634,\n        5.4605, 1.4952, 4.0433, 4.1261, 3.2133, 4.6284, 1.8758, 3.8899, 3.8323,\n        2.8465, 3.8606, 3.9214, 5.3985, 5.0501, 3.7157, 3.3969, 4.6939, 2.4360,\n        5.2907, 1.1016, 4.3594, 3.0314, 4.0438, 1.7346, 3.4324, 4.1993, 1.6376,\n        4.2986, 4.0750, 2.2012, 4.4705, 2.5391, 2.9143, 2.4305, 4.9783, 2.7799,\n        1.9568, 1.4383, 3.6168, 4.6512, 2.0329, 4.5920, 3.7189, 2.4136, 1.1297,\n        2.5500, 4.2853, 2.4182, 2.9808, 1.9679, 5.3402, 4.2904, 2.5738, 2.8753,\n        2.9890, 2.9681, 3.2715, 3.5154, 5.9315, 3.1639, 3.6384, 4.4139, 2.4843,\n        3.2656, 2.6114, 3.1534, 1.6492, 4.2354, 1.7018, 2.6883, 2.8166, 3.5043,\n        3.0860, 4.6004, 3.4558, 4.8038, 3.9778, 2.9297, 2.1587, 3.9946, 2.5446,\n        4.0538, 4.8665, 3.4590, 1.7193, 4.3128, 2.9730, 4.8749, 4.2817, 2.1314,\n        3.0608, 4.7554, 2.3700, 5.3854, 4.1466, 2.5890, 3.6023, 3.5721, 3.1005,\n        2.5122, 5.2542, 4.1164, 4.3411, 4.9667, 3.1611, 1.8987, 3.3484, 3.0425,\n        4.7172, 2.2071, 6.0121, 3.3451, 5.1327, 5.0746, 2.7375, 4.1788, 4.0787,\n        3.3369, 4.7302, 3.0763, 4.8329, 4.6183, 4.1865, 0.9656, 6.1165, 3.0259,\n        1.7801, 3.4640, 3.2774, 4.1853, 1.6130, 5.1586, 5.7576, 4.8708, 4.8904,\n        2.1909, 3.3749, 4.6903, 3.4301, 4.9998, 3.1127, 5.6577, 4.4222, 1.7757,\n        3.2485, 1.9961, 2.3486, 3.8529, 2.7198, 4.0464, 3.3519, 1.7656, 4.4008,\n        5.4976, 3.8522, 2.2879, 1.1546, 2.2022, 4.4012, 3.7354, 1.6696, 3.5627,\n        1.5460, 4.6488, 4.6712, 2.4939, 4.2322, 2.4766, 3.8027, 4.3004, 4.1955,\n        4.5342, 4.5071, 3.0329, 2.5821, 4.5070, 4.7934, 4.4442, 4.0621, 2.9207,\n        2.8465, 4.2419, 4.9268, 4.4527, 2.9298, 4.6919, 5.0903, 4.6572, 2.1507,\n        2.6185, 2.5054, 2.4475, 3.0125, 3.8792, 2.0127, 5.5918, 3.1520, 3.7271,\n        4.2969, 4.0784, 4.6646, 1.7564, 2.5421, 2.0721, 3.7527, 4.1981, 3.9840,\n        
4.1955, 2.4768, 4.7183, 2.6201, 4.0921, 3.2859, 1.7844, 2.5540, 4.2687,\n        3.1899, 4.5519, 4.3004, 2.6197, 5.2841, 3.1591, 3.0768, 5.8369, 3.8521,\n        2.9292, 3.4007, 3.9730, 4.8677, 4.9375, 2.1241, 4.6506, 4.1671, 4.6104,\n        2.9725, 2.6816, 3.5463, 3.2824, 3.3844, 3.6954, 3.4102, 4.9265, 1.5247,\n        3.3365, 3.0972, 4.7498, 4.6413, 5.0101, 2.2056, 2.8666, 2.4416, 2.6975,\n        3.2527, 5.0393, 2.2849, 2.6048, 4.1109, 2.4812, 2.9607, 2.8900, 3.0997,\n        3.3855, 3.5021, 2.3622, 4.0371, 5.1019, 3.5159, 2.2130, 3.6075, 5.1245,\n        5.6870, 4.5442, 2.6910, 3.7526, 4.8621, 1.0371, 4.0803, 3.2929, 5.7485,\n        1.1180, 3.3530, 1.5896, 4.1022, 1.9216, 1.7424, 3.6737, 4.3372, 3.2322,\n        3.0616, 3.6858, 4.9889, 5.2413, 2.9448, 4.4059, 2.1860, 4.6691, 1.4255,\n        2.9349, 3.4576, 3.6847, 5.4328, 2.7601, 4.3654, 4.0444, 4.1013, 4.0687,\n        3.4711, 2.4279, 3.1631, 2.4846, 2.6946, 2.9483, 4.9985, 3.9978, 4.3447,\n        2.9443, 3.8225, 4.8873, 4.3667, 4.6956, 1.8895, 4.4165, 4.5875, 3.3168,\n        3.0381, 2.8014, 4.8542, 3.8810, 4.6255, 3.4070, 4.3673, 1.1123, 4.1492,\n        4.2641, 3.0680, 2.2974, 5.9774, 4.1215, 4.0750, 2.6316, 5.1703, 5.6843,\n        2.5828, 2.6558, 2.4168, 3.9511, 2.9365, 2.5033, 1.6729, 4.9357, 2.5544,\n        3.7153, 5.4120, 2.9207, 4.7157, 1.8796, 2.3960, 5.2531, 5.3573, 4.6392,\n        4.3044, 4.8420, 3.1664, 4.7714, 1.4301, 3.0048, 4.5645, 3.0313, 5.3199,\n        2.4873, 4.5625, 4.6951, 5.4823, 4.4926, 3.7380, 2.7199, 4.5471, 2.6159,\n        3.0011, 3.7241, 2.1079, 2.8157, 3.1352, 5.0903, 4.2189, 2.8857, 3.3773,\n        1.9229, 3.4108, 5.8048, 5.3316, 3.0709, 5.8118, 1.8513, 1.5707, 2.1389,\n        4.6018, 5.2672, 2.8558, 3.1922, 3.7917, 4.9414, 4.0275, 3.2594, 2.9054,\n        3.8022, 1.2948, 1.3668, 4.3320, 2.3055, 4.0324, 2.9293, 5.1496, 5.0543,\n        3.5165, 5.1459, 5.1967, 4.2998, 4.6244, 4.6941, 2.8457, 3.0171, 4.2510,\n        4.2386, 1.2601, 1.9886, 3.5795, 3.8134, 3.9686, 3.1661, 
3.8266, 4.4885,\n        3.2984, 3.0298, 4.1166, 2.3968, 4.5340, 5.0519, 3.5300, 2.7145, 3.7158,\n        3.6781, 1.9388, 4.3390, 4.9633, 5.2593, 5.5802, 4.2703, 3.8694, 5.0108,\n        4.6149, 3.0081, 3.1337, 1.8906, 2.2507, 4.9301, 3.4584, 4.5505, 4.6651,\n        4.0257, 4.1322, 4.3152, 3.6512, 3.9315, 3.7059, 3.9587, 3.0877, 4.5452,\n        3.5187, 4.5992, 3.4337, 5.2965, 2.0313, 1.4726, 1.0297, 4.0122, 5.9539,\n        3.0101, 2.3816, 4.7509, 5.9921, 2.2130, 3.2026, 4.1773, 3.3438, 4.2327,\n        3.5339, 4.1111, 1.2231, 4.1321, 4.1261, 4.2311, 2.7761, 5.3019, 3.8807,\n        3.0990, 2.6014, 1.8158, 4.2587, 2.8624, 2.6647, 4.4088, 4.7509, 4.4043,\n        5.7698, 1.6279, 4.2021, 3.3451, 2.0065, 3.6195, 4.7238, 2.1917, 5.2906,\n        4.9076, 3.5087, 3.5071, 2.9211, 3.1794, 3.4200, 1.9849, 3.9643, 2.0127,\n        3.5784, 5.8916, 5.9385, 4.6156, 1.6393, 2.6396, 3.7590, 4.0292, 3.7251,\n        4.6762, 2.2203, 3.6392, 2.5612, 2.8634, 1.7247, 2.5979, 4.4605, 3.6831,\n        3.5528, 3.1704, 4.1762, 4.3548, 4.2143, 4.2168, 5.5656, 3.7312, 5.4832,\n        2.9260, 3.5884, 3.2168, 3.8471, 4.7193, 3.4517, 2.9338, 2.0219, 3.6204,\n        2.5630], dtype=torch.float64)\n"
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "var_nums = 2\n",
    "sample_nums = 1000\n",
    "true_bias = 4.2\n",
    "X = torch.rand(sample_nums, var_nums)  # shape=[1000, 2]\n",
    "print('X.shape=', X.shape)\n",
    "true_weight = torch.tensor([2, -3.4]).view(2, 1)  # shape=[2, 1]\n",
    "print('true_weights.shape', true_weight.shape)\n",
    "error = torch.tensor(np.random.normal(0.0, 0.1, sample_nums))\n",
    "print('error.shape', error.shape)\n",
    "Y = torch.mm(X, true_weight).view(sample_nums) + true_bias + error\n",
    "print('Y.shape', Y.shape)\n",
    "print(Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 定义基本工具函数\n",
    "+ 模型函数，指定样本和权重，对应的输出\n",
    "+ 初始化参数函数，初始化所有必须的参数\n",
    "+ 损失函数，初始化衡量误差的函数\n",
    "+ 样本获取函数，如何批量获取样本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义回归模型\n",
    "def linear_reg_model(weight, bias, input, batch_size):\n",
    "    return torch.mm(input, weight).view(batch_size) + bias\n",
    "\n",
    "def params_init():\n",
    "    p_weights = torch.randn(2, 1, dtype=torch.float32, requires_grad=True)\n",
    "    p_bias = torch.zeros(1, dtype=torch.float32, requires_grad=True)\n",
    "    return p_weights, p_bias\n",
    "\n",
    "def loss_func(true_Y, hat_Y):\n",
    "    error = true_Y - hat_Y\n",
    "    return error ** 2\n",
    "\n",
    "def sample_batchs(X, true_Y, sample_nums, batch_size):\n",
    "    res = []\n",
    "    inds = list(range(0, sample_nums))\n",
    "    np.random.shuffle(inds)\n",
    "    cur_ind = 0\n",
    "    while cur_ind + batch_size < sample_nums:\n",
    "        keep_inds = inds[cur_ind:cur_ind + batch_size]\n",
    "        res.append((torch.index_select(X, 0, torch.tensor(keep_inds, dtype=torch.int64)), \n",
    "            torch.index_select(true_Y, 0, torch.tensor(keep_inds, dtype=torch.int64))))\n",
    "        cur_ind += batch_size\n",
    "    keep_inds = inds[cur_ind:cur_ind + batch_size]\n",
    "    if keep_inds:\n",
    "        res.append((torch.index_select(X, 0, torch.tensor(keep_inds, dtype=torch.int64)), \n",
    "            torch.index_select(true_Y, 0, torch.tensor(keep_inds, dtype=torch.int64))))\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "(tensor([[0.6280, 0.7622],\n        [0.4410, 0.2582],\n        [0.0578, 0.3930],\n        [0.2031, 0.2891],\n        [0.8529, 0.1454],\n        [0.7091, 0.5959],\n        [0.1866, 0.2087],\n        [0.6091, 0.6290],\n        [0.6470, 0.8776],\n        [0.6175, 0.3982],\n        [0.5556, 0.0076],\n        [0.4439, 0.1246],\n        [0.8183, 0.7475],\n        [0.8767, 0.8118],\n        [0.1617, 0.8198],\n        [0.6508, 0.3082],\n        [0.0043, 0.0019],\n        [0.9388, 0.0294],\n        [0.4366, 0.4452],\n        [0.3214, 0.3705],\n        [0.7145, 0.3275],\n        [0.4385, 0.5287],\n        [0.0708, 0.3658],\n        [0.9899, 0.2482],\n        [0.4600, 0.8385],\n        [0.4917, 0.0014],\n        [0.4817, 0.9047],\n        [0.6969, 0.1907],\n        [0.4987, 0.2497],\n        [0.1685, 0.7173],\n        [0.0289, 0.1620],\n        [0.7326, 0.0606],\n        [0.0928, 0.4620],\n        [0.4855, 0.3306],\n        [0.8491, 0.5779],\n        [0.3495, 0.8891],\n        [0.8898, 0.7307],\n        [0.6151, 0.9724],\n        [0.6910, 0.6853],\n        [0.5420, 0.5667],\n        [0.5551, 0.2120],\n        [0.9351, 0.8459],\n        [0.8611, 0.1986],\n        [0.0270, 0.7175],\n        [0.9641, 0.0662],\n        [0.8389, 0.8335],\n        [0.8923, 0.2658],\n        [0.8853, 0.9439],\n        [0.7339, 0.5074],\n        [0.9003, 0.3802],\n        [0.1049, 0.1238],\n        [0.1131, 0.2345],\n        [0.9150, 0.0942],\n        [0.1609, 0.9381],\n        [0.9456, 0.1964],\n        [0.2355, 0.1448],\n        [0.1950, 0.1956],\n        [0.8835, 0.6706],\n        [0.6235, 0.1018],\n        [0.5764, 0.9557],\n        [0.6632, 0.3311],\n        [0.6370, 0.9917],\n        [0.6304, 0.5103],\n        [0.4415, 0.8575],\n        [0.1466, 0.2306],\n        [0.5181, 0.4009],\n        [0.6544, 0.2948],\n        [0.6562, 0.6156],\n        [0.8825, 0.9111],\n        [0.5743, 0.3906],\n        [0.5876, 0.9150],\n        [0.9252, 0.1961],\n        [0.8340, 0.9238],\n        
[0.0411, 0.9291],\n        [0.5805, 0.6771],\n        [0.6178, 0.4547],\n        [0.3392, 0.4900],\n        [0.0872, 0.1900],\n        [0.9877, 0.0382],\n        [0.7980, 0.8839],\n        [0.2023, 0.1805],\n        [0.2732, 0.0731],\n        [0.8967, 0.5701],\n        [0.3851, 0.3545],\n        [0.1511, 0.1982],\n        [0.8895, 0.6043],\n        [0.8333, 0.1464],\n        [0.1244, 0.8028],\n        [0.1767, 0.2675],\n        [0.3003, 0.5536],\n        [0.5765, 0.4266],\n        [0.3648, 0.1298],\n        [0.8082, 0.5981],\n        [0.9760, 0.5629],\n        [0.6857, 0.1889],\n        [0.1251, 0.3080],\n        [0.5873, 0.2586],\n        [0.5887, 0.7900],\n        [0.0930, 0.1074],\n        [0.0476, 0.8156]]), tensor([2.8906, 4.3362, 2.9646, 3.6462, 5.4315, 3.5991, 3.9805, 2.9825, 2.5470,\n        3.9670, 5.2531, 4.6165, 3.2214, 3.2924, 1.7778, 4.3053, 4.1312, 6.0847,\n        3.6037, 3.6941, 4.3623, 3.1994, 3.3199, 5.4695, 2.1184, 5.2630, 2.2990,\n        4.9043, 4.2856, 2.1789, 3.5854, 5.6168, 2.9551, 4.2172, 3.9352, 2.0182,\n        3.5040, 2.1005, 3.1112, 3.3202, 4.5746, 3.2676, 5.2660, 1.9539, 5.8306,\n        3.0599, 5.0031, 2.8719, 3.7828, 4.7218, 3.8973, 3.4195, 5.8308, 1.2990,\n        5.4681, 4.2088, 3.9665, 3.7652, 5.2025, 2.0962, 4.2998, 2.0345, 3.7822,\n        2.2527, 3.5661, 3.7562, 4.3604, 3.4552, 2.8847, 4.2253, 2.1551, 5.4506,\n        2.7631, 1.1871, 2.9271, 3.8445, 3.0862, 3.7643, 6.0950, 2.7255, 3.9656,\n        4.6060, 4.1290, 3.8249, 3.7730, 3.8856, 5.3338, 1.6296, 3.6672, 2.9325,\n        3.9316, 4.4571, 3.7545, 4.1802, 5.0757, 3.4068, 4.5020, 2.8203, 4.0444,\n        1.7128]))\n"
    }
   ],
   "source": [
    "res = sample_batchs(X, Y, 1000, 100)\n",
    "print(res[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 开始模型训练\n",
    "采用微批梯度下降法，需要指定以下参数：\n",
    "+ epoch_nums: 整个样本迭代训练几次\n",
    "+ batch_nums: 每次微批的数据量大小\n",
    "+ X: 输入的数据 X\n",
    "+ ture_Y: 样本真正的 Y\n",
    "+ step_ratio: 迭代步长"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "tensor([[1.8841],\n        [0.9425]], requires_grad=True)\ntensor([0.], requires_grad=True)\nepoch=1\nepoch=1, loss=0.6970049738883972\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 2.5768],\n        [-1.1153]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([2.6210], requires_grad=True)\nepoch=2\nepoch=2, loss=0.3500586450099945\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 2.4202],\n        [-2.3266]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([3.3870], requires_grad=True)\nepoch=3\nepoch=3, loss=0.19472648203372955\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 2.2607],\n        [-2.8783]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([3.7845], requires_grad=True)\nepoch=4\nepoch=4, loss=0.13141153752803802\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 2.1431],\n        [-3.1482]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([3.9715], requires_grad=True)\nepoch=5\nepoch=5, loss=0.10819939523935318\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 2.0779],\n        [-3.2663]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.0841], requires_grad=True)\nepoch=6\nepoch=6, loss=0.10235301405191422\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 2.0365],\n        [-3.3299]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1309], requires_grad=True)\nepoch=7\nepoch=7, loss=0.10016227513551712\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 2.0232],\n        [-3.3520]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1696], requires_grad=True)\nepoch=8\nepoch=8, loss=0.09938601404428482\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 2.0085],\n        [-3.3718]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1782], requires_grad=True)\nepoch=9\nepoch=9, loss=0.09924861788749695\ntrue_weight=tensor([[ 2.0000],\n   
     [-3.4000]]), p_weight=tensor([[ 2.0047],\n        [-3.3809]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1856], requires_grad=True)\nepoch=10\nepoch=10, loss=0.10020212084054947\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9944],\n        [-3.3869]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1819], requires_grad=True)\nepoch=11\nepoch=11, loss=0.09921132773160934\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9977],\n        [-3.3835]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1958], requires_grad=True)\nepoch=12\nepoch=12, loss=0.0999988317489624\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9923],\n        [-3.3904]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1863], requires_grad=True)\nepoch=13\nepoch=13, loss=0.09925936162471771\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9949],\n        [-3.3878]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1922], requires_grad=True)\nepoch=14\nepoch=14, loss=0.09929240494966507\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9949],\n        [-3.3867]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1909], requires_grad=True)\nepoch=15\nepoch=15, loss=0.09920921176671982\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9944],\n        [-3.3848]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1923], requires_grad=True)\nepoch=16\nepoch=16, loss=0.09928484261035919\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9947],\n        [-3.3846]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1901], requires_grad=True)\nepoch=17\nepoch=17, loss=0.10140980035066605\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9884],\n        [-3.3908]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1802], requires_grad=True)\nepoch=18\nepoch=18, 
loss=0.10026668757200241\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9979],\n        [-3.3770]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.2040], requires_grad=True)\nepoch=19\nepoch=19, loss=0.09923431277275085\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9953],\n        [-3.3822]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1970], requires_grad=True)\nepoch=20\nepoch=20, loss=0.09920695424079895\ntrue_weight=tensor([[ 2.0000],\n        [-3.4000]]), p_weight=tensor([[ 1.9920],\n        [-3.3840]], requires_grad=True)\ntrue_bias=4.2, p_bias=tensor([4.1933], requires_grad=True)\n"
    }
   ],
   "source": [
    "epoch_nums = 20\n",
    "batch_nums = 5\n",
    "step_ratio = 0.03\n",
    "X = X\n",
    "true_Y = Y\n",
    "p_weight, p_bias = params_init()\n",
    "print(p_weight)\n",
    "print(p_bias)\n",
    "\n",
    "for epoch in range(0, epoch_nums):\n",
    "    print(\"epoch={}\".format(epoch+1))\n",
    "    batchs = sample_batchs(X, true_Y, sample_nums, batch_nums)\n",
    "    for batch in batchs:\n",
    "        b_X, b_true_Y = batch\n",
    "        batch_size = b_true_Y.shape[0]\n",
    "        b_hat_Y = linear_reg_model(p_weight, p_bias, b_X, batch_size)\n",
    "        loss = loss_func(b_true_Y, b_hat_Y).sum()\n",
    "\n",
    "        loss.backward()\n",
    "        p_weight.data -= step_ratio*p_weight.grad/batch_size\n",
    "        p_bias.data -= step_ratio*p_bias.grad/batch_size\n",
    "\n",
    "        p_weight.grad.data.zero_()\n",
    "        p_bias.grad.data.zero_()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        hat_Y = linear_reg_model(p_weight, p_bias, X, sample_nums)\n",
    "        loss = loss_func(true_Y, hat_Y)\n",
    "        print(\"epoch={}, loss={}\".format(epoch+1, torch.sqrt(loss.mean())))\n",
    "        print(\"true_weight={}, p_weight={}\".format(true_weight, p_weight))\n",
    "        print(\"true_bias={}, p_bias={}\".format(true_bias, p_bias))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 线性回归高级版本实现\n",
    "通过 PyTorch Lightning 来实现， https://github.com/PyTorchLightning/pytorch-lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import nessasary lib\n",
    "import os\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import TensorDataset\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LineRegModel(pl.LightningModule):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(LineRegModel, self).__init__()\n",
    "        self.true_weight = torch.tensor([2, -3.4], dtype=torch.float32).view(2, 1)\n",
    "        self.true_bias = torch.tensor(4.2, dtype=torch.float32)\n",
    "        # 定义模型结构\n",
    "        self.l1 = torch.nn.Linear(2, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 必须：定义模型\n",
    "        return self.l1(x)\n",
    "\n",
    "    def training_step(self, batch, batch_nb):\n",
    "        # 必须提供：定于训练过程\n",
    "        x, y = batch\n",
    "        y_hat = self(x)\n",
    "        loss = F.mse_loss(y_hat, y)\n",
    "        tensorboard_logs = {'train_loss': loss}\n",
    "        return {'loss': loss, 'log': tensorboard_logs}\n",
    "\n",
    "    def validation_step(self, batch, batch_nb):\n",
    "        # 可选提供：定义验证过程\n",
    "        x, y = batch\n",
    "        y_hat = self(x)\n",
    "        \n",
    "        return {'val_loss': F.mse_loss(y_hat, y)}\n",
    "\n",
    "    def validation_epoch_end(self, outputs):\n",
    "        # 可选提供：定义验证过程\n",
    "        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n",
    "        tensorboard_logs = {'val_loss': avg_loss}\n",
    "        return {'val_loss': avg_loss, 'log': tensorboard_logs}\n",
    "\n",
    "    def test_step(self, batch, batch_nb):\n",
    "        # 可选提供：定义测试过程\n",
    "        x, y = batch\n",
    "        y_hat = self(x)\n",
    "        return {'test_loss': F.mse_loss(y_hat, y)}\n",
    "\n",
    "    def test_epoch_end(self, outputs):\n",
    "        # 可选提供：定义测试过程\n",
    "        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()\n",
    "        logs = {'test_loss': avg_loss}\n",
    "        return {'test_loss': avg_loss, 'log': logs, 'progress_bar': logs}\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        # 必须提供：定义优化器\n",
    "        # can return multiple optimizers and learning_rate schedulers\n",
    "        # (LBFGS it is automatically supported, no need for closure function)\n",
    "        return torch.optim.SGD(self.parameters(), lr=0.04)\n",
    "\n",
    "    def gen_data_loader(self, shuffle, sample_nums, batch_size):\n",
    "        X = torch.rand(sample_nums, 2, dtype=torch.float32) \n",
    "        error = torch.tensor(np.random.normal(0.0, 0.1, sample_nums), dtype=torch.float32).view(sample_nums, 1)\n",
    "        Y = torch.mm(X, self.true_weight) + self.true_bias + error\n",
    "        print(\"True Y Shape = {}\".format(Y.shape))\n",
    "        # 先转换成 torch 能识别的 Dataset\n",
    "        torch_dataset = TensorDataset(X, Y)\n",
    "\n",
    "        # 把 dataset 放入 DataLoader\n",
    "        loader = DataLoader(\n",
    "            dataset=torch_dataset,      # torch TensorDataset format\n",
    "            batch_size=batch_size,      # mini batch size\n",
    "            shuffle=shuffle,            # 要不要打乱数据 (打乱比较好)\n",
    "            num_workers=4,              # 多线程来读数据\n",
    "        )\n",
    "        return loader\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        # 必须提供：提供训练数据集\n",
    "        return self.gen_data_loader(True, 1000, 20)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        # 可选提供：提供验证数据集\n",
    "        return self.gen_data_loader(False, 1000, 1000)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        # 可选提供：提供测试数据集\n",
    "        return self.gen_data_loader(False, 1000, 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "GPU available: False, used: False\nTPU available: False, using: 0 TPU cores\n\n  | Name | Type   | Params\n--------------------------------\n0 | l1   | Linear | 3     \nTrue Y Shape = torch.Size([1000, 1])\nTrue Y Shape = torch.Size([1000, 1])\n"
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "16f8d6a1dbaa4ac8a386b107456c9ce4"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "844832ac819540e89b7d2eba0efc3413"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "64d3fb23a4e8402f812c387f15e3d5ed"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "1"
     },
     "metadata": {},
     "execution_count": 38
    }
   ],
   "source": [
    "lr_model = LineRegModel()\n",
    "\n",
    "# Basic trainer with default settings: train for 20 epochs and skip the pre-training validation sanity check\n",
    "trainer = pl.Trainer(max_epochs=20, num_sanity_val_steps=0)\n",
    "trainer.fit(lr_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Parameter containing:\ntensor([[ 2.0270, -3.3655]], requires_grad=True)\nParameter containing:\ntensor([4.1701], requires_grad=True)\n"
    }
   ],
   "source": [
    "# Print all model parameters\n",
    "for i in lr_model.parameters():\n",
    "    print(i)\n",
    "\n",
    "true_weight = torch.tensor([2, -3.4], dtype=torch.float32).view(2, 1)\n",
    "true_bias = torch.tensor(4.2, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "True Y Shape = torch.Size([1000, 1])\n"
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "8ca587c8c76947d396b9ca79a55a3571"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "--------------------------------------------------------------------------------\nTEST RESULTS\n{'test_loss': tensor(0.0112)}\n--------------------------------------------------------------------------------\n\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "{'test_loss': 0.0112014040350914}"
     },
     "metadata": {},
     "execution_count": 40
    }
   ],
   "source": [
    "trainer.test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
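    "## PyTorch Lightning\n",
    "Lightning is a further layer of abstraction over PyTorch. Users no longer need to handle the surrounding boilerplate; they only need to focus on defining the model, defining the optimizer, and preparing the training, validation, and test sets. Lightning takes care of the rest."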
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}